-- -- Derivative of the loss with respect to the offset (bias) term of the biased 2-D convolution.
-- drop function if exists sm_sc.fv_d_conv_2d_dloss_dindepdt_3(float[], int[]);
create or replace function sm_sc.fv_d_conv_2d_dloss_dindepdt_3
(
  i_dloss_ddepdt           float[]           ,                          -- the already-computed gradient of the loss w.r.t. the output y
  i_indepdt_var_len        int[]   default array[1, 1]                  -- presumably the bias shape; only its length (2, 3 or 4) is read, to pick the reduction
)
returns float[]
as
$$
-- Reduces dloss/dy to the gradient of the bias term, chosen by the length of i_indepdt_var_len:
--   * 2 elements: wraps the single-argument sm_sc.fv_aggr_slice_sum result
--     (presumably the total sum of i_dloss_ddepdt) in a 1x1 array;
--   * 3 elements: sm_sc.fv_aggr_slice_sum with slice lengths [1, len(dim 2), len(dim 3)],
--     presumably collapsing dims 2 and 3 while keeping dim 1;
--   * 4 elements: slice lengths [1, 1, len(dim 3), len(dim 4)],
--     presumably collapsing dims 3 and 4 while keeping dims 1 and 2.
-- Any other length, including a NULL i_indepdt_var_len, raises invalid_parameter_value.
-- This replaces falling off the end of the function with the uninformative
-- "control reached end of function without RETURN" error.
-- NOTE(review): the number of dimensions of i_dloss_ddepdt is not checked against
-- i_indepdt_var_len; a mismatch passes NULL slice lengths to fv_aggr_slice_sum.
declare
  v_var_len_ndims   int   := array_length(i_indepdt_var_len, 1);
begin
  case v_var_len_ndims
    when 2
    then
      return array[array[sm_sc.fv_aggr_slice_sum(i_dloss_ddepdt)]];

    when 3
    then
      return sm_sc.fv_aggr_slice_sum(i_dloss_ddepdt, array[1, array_length(i_dloss_ddepdt, 2), array_length(i_dloss_ddepdt, 3)]);

    when 4
    then
      return sm_sc.fv_aggr_slice_sum(i_dloss_ddepdt, array[1, 1, array_length(i_dloss_ddepdt, 3), array_length(i_dloss_ddepdt, 4)]);

    else
      raise exception 'unsupported length of i_indepdt_var_len: %; expected 2, 3 or 4', v_var_len_ndims
        using errcode = 'invalid_parameter_value';
  end case;
end
$$
language plpgsql stable
parallel safe
cost 100;
-- -- set search_path to sm_sc;
-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_3
--   (
--     array
--     [
--       [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--     , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--     , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--     , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--     , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--     ] :: float[]
--   )

-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_3
--   (
--     array
--     [
--       [
--         [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--       , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--       , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--       , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--       , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--       ]
--     , [
--         [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--       , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--       , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--       , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--       , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--       ]
--     ]
--   , array[2, 1, 1]
--   )

-- select sm_sc.fv_d_conv_2d_dloss_dindepdt_3
--   (
--     array
--     [
--       [
--         [
--           [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--         , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--         , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--         , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--         , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--         ]
--       , [
--           [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--         , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--         , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--         , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--         , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--         ]
--       ]
--     , [
--         [
--           [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--         , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--         , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--         , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--         , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--         ]
--       , [
--           [1.0,2.0,3.0,4.0,5.0,6.0,7.0]
--         , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
--         , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
--         , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
--         , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
--         ]
--       ]
--     ]
--   , array[2, 2, 1, 1]
--   )